#!/usr/bin/env python3
'''
未验证

训练记录：
在笔记本上训练
20241109：训练分数431分，测试分数177.5分，继续训练
20241110：训练分数达到4776，测试分数达到9139.5分，远超linear，继续训练
20241111：训练分数达到5900+分，测试分数达到9139.5分，继续训练
20241112：训练分数达到5400+分，测试分数达到9535.0分，继续训练
20241113：由于蓝屏，训练分数未知，测试分数仍为9535.0分，无提升，暂缓训练，进行play看效果
'''
import gymnasium as gym
import ptan
import numpy as np
import argparse
from tensorboardX import SummaryWriter
import os

import torch
import torch.nn as nn
import torch.nn.utils as nn_utils
import torch.nn.functional as F
import torch.optim as optim

from typing import Any
from lib import common
from collections import deque
import ale_py

# Make the ALE environments (e.g. "ALE/Atlantis-v5") resolvable by gym.make.
gym.register_envs(ale_py)

# --- optimisation hyper-parameters ---
GAMMA: float = 0.99           # discount for n-step returns and value bootstrapping
LEARNING_RATE: float = 5e-4   # initial Adam learning rate (decayed by StepLR in main)
ENTROPY_BETA: float = 0.01    # weight of the entropy term in the loss
BATCH_SIZE: int = 128         # transitions consumed per optimizer update
NUM_ENVS: int = 50            # parallel environments feeding the experience source

REWARD_STEPS: int = 4         # n-step return horizon
CLIP_GRAD: float = 0.5        # max global L2 norm passed to clip_grad_norm_

SAVE_ITERS: int = 100         # write a checkpoint every this many optimizer updates


class TransposeObservation(gym.ObservationWrapper):
    """Observation wrapper that reorders image observations from (H, W, C) to (C, H, W).

    PyTorch ``Conv2d`` expects channel-first input. The wrapper also rewrites
    ``observation_space`` so that its shape matches what ``observation()``
    actually returns; otherwise code that sizes a network from
    ``observation_space.shape`` would see the stale channel-last shape.
    """

    def __init__(self, env=None):
        super().__init__(env)
        old_space = self.observation_space
        # Only a rank-3 Box can be transposed meaningfully; leave anything else untouched.
        if isinstance(old_space, gym.spaces.Box) and len(old_space.shape) == 3:
            self.observation_space = gym.spaces.Box(
                low=old_space.low.transpose(2, 0, 1),
                high=old_space.high.transpose(2, 0, 1),
                dtype=old_space.dtype,
            )

    def observation(self, observation):
        # (H, W, C) -> (C, H, W)
        return observation.transpose(2, 0, 1)


class FireResetEnv(gym.Wrapper):
    """For environments where the user needs to press FIRE for the game to start.

    The wrapped env must expose FIRE as action 1 and have at least three actions,
    because ``reset()`` presses actions 1 and 2 after every reset.
    """

    def __init__(self, env=None):
        super().__init__(env)
        # Games that need FIRE to start expose it as action 1 and have >= 3 actions.
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def step(self, action):
        """Forward the step unchanged to the wrapped env."""
        return self.env.step(action)

    def reset(self, seed: int | None = None, options: dict[str, Any] | None = None):
        """Reset the env, then press action 1 (FIRE) and action 2 to start play.

        If either press ends the episode (terminated or truncated), the env is
        reset again. The returned ``(obs, info)`` always describe the env's
        current state: a press's result, or the fresh reset if that press ended
        the episode.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        for action in (1, 2):
            obs, _, terminated, truncated, info = self.env.step(action)
            if terminated or truncated:
                # Do not pass the seed again: re-seeding would replay the exact
                # same start that just ended, and the seed belongs to the first reset only.
                obs, info = self.env.reset(options=options)
        return obs, info

class AtariA2C(nn.Module):
    """Actor-critic network: a shared conv trunk feeding separate policy and value heads.

    Args:
        input_shape: observation shape in channel-first order (C, H, W); ``input_shape[0]``
            is used as the first Conv2d's input channel count.
        n_actions: number of discrete actions, i.e. the width of the policy logits.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        print("obs_action: ", input_shape)
        obs_action = input_shape

        self.conv = nn.Sequential(
            nn.Conv2d(obs_action[0], 64, kernel_size=8, stride=4),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        conv_out_size = self._get_conv_out(obs_action)
        self.policy = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

        self.value = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def _get_conv_out(self, shape):
        """Return the flattened size of the conv trunk's output for one input of ``shape``.

        The probe runs in eval mode under ``no_grad``. In train mode the dummy zero batch
        would otherwise update the BatchNorm running statistics (and num_batches_tracked).
        The trunk's previous train/eval mode is restored afterwards.
        """
        was_training = self.conv.training
        self.conv.eval()
        try:
            with torch.no_grad():
                o = self.conv(torch.zeros(1, *shape))
        finally:
            self.conv.train(was_training)
        return int(np.prod(o.size()))

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of observations.

        ``x`` is cast to float and scaled by 1/256 before the conv trunk. The trunk
        output is flattened per sample and fed to both heads. Logits have shape
        (N, n_actions) and values have shape (N, 1).
        """
        fx = x.float() / 256
        conv_out = self.conv(fx).view(fx.size()[0], -1)
        return self.policy(conv_out), self.value(conv_out)


def unpack_batch(batch, net, device='cpu', gamma=None, reward_steps=None):
    """
    Convert a batch of n-step transitions into training tensors.

    Each item must expose ``state``, ``action``, ``reward`` and ``last_state``.
    ``last_state`` is ``None`` for transitions whose episode ended. ``reward`` is
    presumably the discounted n-step sum produced by ptan's
    ExperienceSourceFirstLast (verify against the experience source in use).
    For transitions with a ``last_state`` the critic's value is bootstrapped:
        ref = reward + gamma ** reward_steps * V(last_state)

    :param batch: sequence of experience items (see above)
    :param net: callable returning ``(policy_logits, values)`` with values shaped (N, 1)
    :param device: torch device the returned tensors are moved to
    :param gamma: discount factor; defaults to the module-level GAMMA
    :param reward_steps: n-step horizon; defaults to the module-level REWARD_STEPS
    :return: states FloatTensor, actions LongTensor, reference values FloatTensor
    """
    if gamma is None:
        gamma = GAMMA
    if reward_steps is None:
        reward_steps = REWARD_STEPS

    states = []
    actions = []
    rewards = []
    not_done_idx = []  # batch indices of transitions whose episode has not ended
    last_states = []   # last_state of those transitions, aligned with not_done_idx
    for idx, exp in enumerate(batch):
        states.append(np.asarray(exp.state))
        actions.append(int(exp.action))
        rewards.append(exp.reward)
        if exp.last_state is not None:
            not_done_idx.append(idx)
            last_states.append(np.asarray(exp.last_state))
    states_v = torch.FloatTensor(np.asarray(states)).to(device)
    actions_t = torch.LongTensor(actions).to(device)
    rewards_np = np.array(rewards, dtype=np.float32)
    if not_done_idx:
        last_states_v = torch.FloatTensor(np.asarray(last_states)).to(device)
        # The bootstrap value is a fixed target, so no autograd graph is needed.
        # NOTE(review): if net is in train mode, BatchNorm layers still use batch
        # statistics here and update their running stats.
        with torch.no_grad():
            last_vals_v = net(last_states_v)[1]
        last_vals_np = last_vals_v.cpu().numpy()[:, 0]
        rewards_np[not_done_idx] += gamma ** reward_steps * last_vals_np

    ref_vals_v = torch.FloatTensor(rewards_np).to(device)
    return states_v, actions_t, ref_vals_v


class RewardPenaltyWrapper(gym.Wrapper):
    """Reward-shaping wrapper: scales rewards by 1/100 and penalises losing a life.

    Args:
        env: wrapped environment. Lives are read from ``info['lives']``; when the
            key is missing, no life-loss penalty is ever applied.
        frame_penalty: stored on the instance but NOT added to the reward anywhere.
            NOTE(review): applying it would change the training signal of existing
            runs, so it is left unused; confirm whether it was meant to be applied.
        life_loss_penalty: added to the (scaled) reward on any step where the
            number of lives decreases.
    """

    def __init__(self, env, frame_penalty=-0.1, life_loss_penalty=-10):
        super().__init__(env)
        self.frame_penalty = frame_penalty
        self.life_loss_penalty = life_loss_penalty
        self.previous_lives = 0

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.previous_lives = info.get('lives', 0)  # lives at the start of the episode
        return obs, info

    def step(self, action):
        obs, reward, done, truncated, info = self.env.step(action)

        reward /= 100  # scale the raw game reward

        current_lives = info.get('lives', self.previous_lives)
        if current_lives < self.previous_lives:
            reward += self.life_loss_penalty
        # Track lives on every step, not only on a loss. Otherwise a life gained
        # mid-episode leaves a stale low baseline and later losses go unpenalised.
        self.previous_lives = current_lives

        return obs, reward, done, truncated, info
    


def wrap_dqn(env, stack_frames=4, episodic_life=True, reward_clipping=True):
    """Apply this script's Atari preprocessing stack to ``env``.

    Wrapping order, innermost first:
    optional EpisodicLifeEnv -> NoopResetEnv(noop_max=30) ->
    FireResetEnv (only if the game has a FIRE action) -> ProcessFrame84 ->
    ImageToPyTorch -> FrameStack(stack_frames) -> RewardPenaltyWrapper.

    :param env: raw environment, presumably an ALE env created with gym.make
    :param stack_frames: number of consecutive frames stacked into one observation
    :param episodic_life: if True, wrap with ptan's EpisodicLifeEnv first
    :param reward_clipping: accepted for API compatibility; currently unused
    :return: the fully wrapped environment
    """
    wrappers = ptan.common.wrappers

    if episodic_life:
        env = wrappers.EpisodicLifeEnv(env)
    # Random number of no-op actions after reset to diversify start states.
    env = wrappers.NoopResetEnv(env, noop_max=30)

    action_names = env.unwrapped.get_action_meanings()
    if 'FIRE' in action_names:
        env = FireResetEnv(env)

    for frame_wrapper in (wrappers.ProcessFrame84, wrappers.ImageToPyTorch):
        env = frame_wrapper(env)
    env = wrappers.FrameStack(env, stack_frames)

    return RewardPenaltyWrapper(env)


def test_model(env, net, device, episodes=5):
    """Play ``episodes`` greedy episodes and return the mean reward per episode.

    Actions are the argmax of the softmaxed policy logits. An episode stops when
    the env reports done/truncated. It also stops early once action 0 has been
    chosen more than 30 times in a row after the first such choice, which guards
    against a policy that idles forever. Rewards are whatever ``env.step`` returns,
    i.e. already shaped by any reward wrappers on ``env``.
    """
    reward_sum = 0.0
    with torch.no_grad():
        for _episode in range(episodes):
            obs, _ = env.reset()
            last_action = -1
            idle_streak = 0
            while True:
                obs_batch = ptan.agent.default_states_preprocessor([obs]).to(device)
                logits_v, _ = net(obs_batch)
                action_probs = F.softmax(logits_v, dim=1).data.cpu().numpy()
                action = np.argmax(action_probs)

                repeated_noop = action == 0 and action == last_action
                idle_streak = idle_streak + 1 if repeated_noop else 0
                if idle_streak > 30:
                    break
                last_action = action

                obs, reward, done, trunc, _ = env.step(action)
                reward_sum += reward
                if done or trunc:
                    break
    return reward_sum / episodes


def optimized_states_preprocessor(states):
    """
    Turn a list of observations into one batched tensor for the network.

    A single observation gets a new leading batch axis via ``np.expand_dims``.
    Any other count is converted element-wise with ``np.asarray`` and stacked
    along a new first axis.
    :param states: list of numpy arrays (or array-likes) with states
    :return: torch.Tensor created with ``torch.from_numpy`` (numpy's dtype is kept)
    """
    if len(states) != 1:
        batch_np = np.asarray([np.asarray(item) for item in states])
    else:
        batch_np = np.expand_dims(states[0], 0)
    return torch.from_numpy(batch_np)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The old `default=True, action="store_true"` could never be switched off.
    # BooleanOptionalAction accepts both --cuda and --no-cuda, and the default
    # falls back to CPU when CUDA is unavailable instead of crashing.
    parser.add_argument("--cuda", default=torch.cuda.is_available(),
                        action=argparse.BooleanOptionalAction, help="Enable cuda")
    parser.add_argument("-n", "--name", default="breakout", required=False, help="Name of the run")
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    save_path = os.path.join("saves", "a2c-conv" + args.name)
    os.makedirs(save_path, exist_ok=True)

    env_id = "ALE/Atlantis-v5"
    env_kwargs = dict(obs_type='rgb', frameskip=4, repeat_action_probability=0.0)
    envs = [wrap_dqn(gym.make(env_id, **env_kwargs), episodic_life=False) for _ in range(NUM_ENVS)]
    test_env = wrap_dqn(gym.make(env_id, **env_kwargs), episodic_life=False)
    writer = SummaryWriter(comment="-a2c-conv_" + args.name)

    net = AtariA2C(envs[0].observation_space.shape, envs[0].action_space.n).to(device)
    print(net)

    agent = ptan.agent.PolicyAgent(lambda x: net(x)[0], apply_softmax=True, device=device,
                                   preprocessor=optimized_states_preprocessor)
    exp_source = ptan.experience.ExperienceSourceFirstLast(envs, agent, gamma=GAMMA, steps_count=REWARD_STEPS)
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE, eps=1e-3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50000, gamma=0.9)

    batch = []
    frame_idx = 0  # counts optimizer updates (not env frames); drives test/save cadence
    start_idx = 0  # experience-step offset restored from a checkpoint

    # Resume from the newest checkpoint whose file name contains "epoch". Other files
    # (presumably the best-model snapshots) do not match, so the list can be empty even
    # when the directory is not; previously that crashed with IndexError.
    checkpoints = sorted(
        (fname for fname in os.listdir(save_path) if "epoch" in fname),
        key=lambda fname: int(fname.split('_')[2].split('.')[0]),
    )
    if checkpoints:
        checkpoint = torch.load(os.path.join(save_path, checkpoints[-1]), map_location=device, weights_only=False)
        frame_idx = checkpoint['frame_idx']
        start_idx = checkpoint['start_idx']
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print("加载模型成功")
    # Report the current learning rate (after a possible resume).
    print("Learning Rate:", scheduler.get_last_lr()[0])

    with common.RewardTracker(writer, stop_reward=10000) as tracker:
        with ptan.common.utils.TBMeanTracker(writer, batch_size=10) as tb_tracker:
            for step_idx, exp in enumerate(exp_source):
                global_step = step_idx + start_idx
                batch.append(exp)

                # Several envs can finish episodes on the same step: report every
                # finished episode, not only the first one.
                solved = False
                for episode_reward in exp_source.pop_total_rewards():
                    if tracker.reward(episode_reward, global_step):
                        solved = True
                        break
                if solved:
                    break

                if len(batch) < BATCH_SIZE:
                    continue

                states_v, actions_t, vals_ref_v = unpack_batch(batch, net, device=device)
                batch.clear()

                optimizer.zero_grad()
                logits_v, value_v = net(states_v)
                loss_value_v = F.mse_loss(value_v.squeeze(-1), vals_ref_v)

                log_prob_v = F.log_softmax(logits_v, dim=1)
                adv_v = vals_ref_v - value_v.squeeze(-1).detach()
                log_prob_actions_v = adv_v * log_prob_v[range(BATCH_SIZE), actions_t]
                loss_policy_v = -log_prob_actions_v.mean()

                prob_v = F.softmax(logits_v, dim=1)
                entropy_loss_v = ENTROPY_BETA * (prob_v * log_prob_v).sum(dim=1).mean()

                # Backprop the policy loss alone first so its gradients can be logged...
                loss_policy_v.backward(retain_graph=True)
                grads = np.concatenate([p.grad.data.cpu().numpy().flatten()
                                        for p in net.parameters()
                                        if p.grad is not None])

                # ...then accumulate entropy + value gradients on top of them.
                loss_v = entropy_loss_v + loss_value_v
                loss_v.backward()
                nn_utils.clip_grad_norm_(net.parameters(), CLIP_GRAD)
                optimizer.step()
                loss_v += loss_policy_v  # total loss, for logging only
                frame_idx += 1
                scheduler.step()

                if frame_idx % 200 == 0:
                    # Periodic greedy evaluation on the separate test env.
                    net.eval()
                    test_reward = test_model(test_env, net, device=device, episodes=2)
                    net.train()
                    print(f"Test reward: {test_reward:.2f}")
                    common.save_best_model(test_reward, net.state_dict(), save_path, "a2c-best", keep_best=10)

                if frame_idx % SAVE_ITERS == 0:
                    checkpoint = {
                        "net": net.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "frame_idx": frame_idx,
                        "start_idx": global_step,
                        "scheduler": scheduler.state_dict()
                    }
                    common.save_checkpoints(frame_idx, checkpoint, save_path, "a2c", keep_last=5)

                tb_tracker.track("advantage",       adv_v, global_step)
                tb_tracker.track("values",          value_v, global_step)
                tb_tracker.track("batch_rewards",   vals_ref_v, global_step)
                tb_tracker.track("loss_entropy",    entropy_loss_v, global_step)
                tb_tracker.track("loss_policy",     loss_policy_v, global_step)
                tb_tracker.track("loss_value",      loss_value_v, global_step)
                tb_tracker.track("loss_total",      loss_v, global_step)
                tb_tracker.track("grad_l2",         np.sqrt(np.mean(np.square(grads))), global_step)
                tb_tracker.track("grad_max",        np.max(np.abs(grads)), global_step)
                tb_tracker.track("grad_var",        np.var(grads), global_step)